"""This module contains vulnerability types, Enums, nodes and helpers."""

import json
from enum import Enum
from collections import namedtuple

from ..core.node_types import YieldNode


class VulnerabilityType(Enum):
    """Classification of a source-to-sink flow.

    vuln_factory maps SANITISED and UNKNOWN to dedicated Vulnerability
    subclasses; every other member maps to the plain Vulnerability class.
    """

    FALSE = 0      # presumably: the flow is not a vulnerability — confirm with callers
    SANITISED = 1  # the flow passes through a sanitiser
    TRUE = 2       # presumably: a confirmed vulnerability — confirm with callers
    UNKNOWN = 3    # the outcome depends on an unknown assignment


def vuln_factory(vulnerability_type):
    """Pick the Vulnerability class that should represent ``vulnerability_type``.

    UNKNOWN maps to UnknownVulnerability and SANITISED to
    SanitisedVulnerability; any other value falls back to Vulnerability.
    """
    # Checked in order with ``==``; the first match wins.
    special_cases = (
        (VulnerabilityType.UNKNOWN, UnknownVulnerability),
        (VulnerabilityType.SANITISED, SanitisedVulnerability),
    )
    for kind, vuln_class in special_cases:
        if vulnerability_type == kind:
            return vuln_class
    return Vulnerability


def _get_reassignment_str(reassignment_nodes):
    """Render the reassignment chain of a vulnerability for pretty printing.

    Returns an empty string when ``reassignment_nodes`` is falsy. Otherwise
    returns a "Reassigned in:" header followed by one tab-indented
    "File: ... / > Line N: label" entry per node. Each node must expose
    ``path`` and ``label`` as str, plus ``line_number``.
    """
    if not reassignment_nodes:
        return ''

    entries = []
    for node in reassignment_nodes:
        entries.append(''.join((
            'File: ', node.path,
            '\n\t > Line ', str(node.line_number),
            ': ', node.label,
        )))
    return '\nReassigned in:\n\t' + '\n\t'.join(entries)


class Vulnerability():
    """A flow of user input from a source node to a sink node.

    ``as_dict`` records the concrete class name under ``'type'``, so
    subclasses serialise distinguishably.
    """

    def __init__(
        self,
        source,
        source_trigger_word,
        sink,
        sink_trigger_word,
        reassignment_nodes
    ):
        """Set source and sink information.

        Args:
            source: node where user input enters; must expose ``path``,
                ``line_number``, ``label`` and ``as_dict()``.
            source_trigger_word: trigger word that matched at ``source``.
            sink: node the input reaches; same attributes as ``source``.
            sink_trigger_word: trigger word that matched at ``sink``.
            reassignment_nodes: iterable of intermediate nodes, each exposing
                ``path``, ``line_number``, ``label`` and ``as_dict()``.

        ``self.reassignment_nodes`` is always a new list. The caller's
        ``reassignment_nodes`` object is never modified.
        """
        self.source = source
        self.source_trigger_word = source_trigger_word
        self.sink = sink
        self.sink_trigger_word = sink_trigger_word

        self.reassignment_nodes = reassignment_nodes
        self._remove_non_propagating_yields()

    def _remove_non_propagating_yields(self):
        """Remove yield with no variables e.g. `yield 123` and plain `yield` from vulnerability.

        A YieldNode whose ``right_hand_side_variables`` has exactly one entry
        is treated as propagating nothing. Presumably that single entry is the
        yield's own target; confirm against YieldNode construction.
        """
        # Build a filtered copy instead of calling list.remove() on the
        # caller's list. This avoids mutating an argument the caller may still
        # hold, is a single O(n) pass, and also accepts tuples and generators.
        self.reassignment_nodes = [
            node for node in self.reassignment_nodes
            if not (
                isinstance(node, YieldNode) and
                len(node.right_hand_side_variables) == 1
            )
        ]

    def __str__(self):
        """Pretty printing of a vulnerability.

        Shows the source location and trigger, the reassignment chain (if
        any), and the sink location and trigger.
        """
        reassigned_str = _get_reassignment_str(self.reassignment_nodes)
        return (
            'File: {}\n'
            ' > User input at line {}, source "{}":\n'
            '\t {}{}\nFile: {}\n'
            ' > reaches line {}, sink "{}":\n'
            '\t{}'.format(
                self.source.path,
                self.source.line_number, self.source_trigger_word,
                self.source.label, reassigned_str, self.sink.path,
                self.sink.line_number, self.sink_trigger_word,
                self.sink.label
            )
        )

    def as_dict(self):
        """Serialise to a plain dict (nodes via their own ``as_dict()``)."""
        return {
            'source': self.source.as_dict(),
            'source_trigger_word': self.source_trigger_word,
            'sink': self.sink.as_dict(),
            'sink_trigger_word': self.sink_trigger_word,
            'type': self.__class__.__name__,
            'reassignment_nodes': [node.as_dict() for node in self.reassignment_nodes]
        }


class SanitisedVulnerability(Vulnerability):
    """A vulnerability whose flow passes through a sanitiser.

    ``confident`` only affects the printed wording ("sanitised" vs.
    "potentially sanitised"). ``sanitiser`` appears in both the string and
    dict forms and must expose ``as_dict()``.
    """

    def __init__(
        self,
        confident,
        sanitiser,
        **kwargs
    ):
        """Forward ``kwargs`` to Vulnerability, then record the sanitiser details."""
        super().__init__(**kwargs)
        self.confident, self.sanitiser = confident, sanitiser

    def __str__(self):
        """Pretty printing of a vulnerability."""
        base_text = super().__str__()
        qualifier = '' if self.confident else 'potentially '
        return '{}\nThis vulnerability is {}sanitised by: {}'.format(
            base_text,
            qualifier,
            str(self.sanitiser),
        )

    def as_dict(self):
        """Extend the base serialisation with the sanitiser and confidence flag."""
        return {
            **super().as_dict(),
            'sanitiser': self.sanitiser.as_dict(),
            'confident': self.confident,
        }


class UnknownVulnerability(Vulnerability):
    """A vulnerability whose status hinges on an assignment that could not be resolved.

    ``unknown_assignment`` is shown in the printed message and serialised via
    its own ``as_dict()``.
    """

    def __init__(self, unknown_assignment, **kwargs):
        """Forward ``kwargs`` to Vulnerability, then record the unknown assignment."""
        super().__init__(**kwargs)
        self.unknown_assignment = unknown_assignment

    def as_dict(self):
        """Extend the base serialisation with the unknown assignment."""
        return {
            **super().as_dict(),
            'unknown_assignment': self.unknown_assignment.as_dict(),
        }

    def __str__(self):
        """Pretty printing of a vulnerability."""
        base_text = super().__str__()
        return '{}\nThis vulnerability is unknown due to: {}'.format(
            base_text,
            str(self.unknown_assignment),
        )


# A sanitiser occurrence: its trigger word and the associated CFG node.
Sanitiser = namedtuple('Sanitiser', 'trigger_word cfg_node')


# Bundle of trigger definitions: sources, sinks and the sanitiser mapping.
Triggers = namedtuple('Triggers', ['sources', 'sinks', 'sanitiser_dict'])


class TriggerNode():
    """A CFG node matched by a trigger, plus any secondary nodes recorded for it.

    ``trigger`` must expose ``trigger_word`` and may expose ``sanitisers``.
    """

    def __init__(
        self,
        trigger,
        cfg_node,
        secondary_nodes=None
    ):
        """Store the trigger, the matched node and any secondary nodes.

        Args:
            trigger: object with a ``trigger_word`` attribute and, optionally,
                a ``sanitisers`` attribute.
            cfg_node: the node the trigger matched; ``append`` compares
                candidates against it with ``==``.
            secondary_nodes: optional list of additional nodes. When given,
                that same list object is kept (not copied). When omitted, each
                instance gets its own new empty list.
        """
        self.trigger = trigger
        self.cfg_node = cfg_node
        # A literal ``[]`` default would be one list shared by every instance
        # created without ``secondary_nodes``; create a fresh one instead.
        self.secondary_nodes = [] if secondary_nodes is None else secondary_nodes

    @property
    def trigger_word(self):
        """The trigger word of the underlying trigger."""
        return self.trigger.trigger_word

    @property
    def sanitisers(self):
        """The trigger's sanitisers, or [] if the trigger has none."""
        return self.trigger.sanitisers if hasattr(self.trigger, 'sanitisers') else []

    def append(self, cfg_node):
        """Record ``cfg_node`` as a secondary node.

        Ignored if it equals the primary node or is already recorded.
        """
        if cfg_node == self.cfg_node:
            return
        if not self.secondary_nodes:
            # Rebind (rather than append) while empty, so an empty list that
            # the caller passed in is never mutated.
            self.secondary_nodes = [cfg_node]
        elif cfg_node not in self.secondary_nodes:
            self.secondary_nodes.append(cfg_node)

    def __repr__(self):
        output = 'TriggerNode('

        if self.trigger_word:
            output = '{} trigger_word is {}, '.format(
                output,
                self.trigger_word
            )

        return (
            output +
            'sanitisers are {}, '.format(self.sanitisers) +
            'cfg_node is {})\n'.format(self.cfg_node)
        )


def get_vulnerabilities_not_in_baseline(
    vulnerabilities,
    baseline_file
):
    """Return the vulnerabilities that are not recorded in a baseline file.

    Args:
        vulnerabilities: iterable of objects with an ``as_dict()`` method.
        baseline_file: path to a JSON file whose top-level object has a
            ``'vulnerabilities'`` list of serialised vulnerabilities.

    Returns:
        A list of the items in ``vulnerabilities`` whose ``as_dict()`` equals
        no entry of the baseline list, in input order.
    """
    # Close the file deterministically instead of leaking the handle until GC.
    with open(baseline_file) as baseline_fd:
        baseline = json.load(baseline_fd)
    known = baseline['vulnerabilities']
    # Dicts are unhashable, so membership is a linear equality scan per item.
    return [vuln for vuln in vulnerabilities if vuln.as_dict() not in known]
